import pandas as pd

from model import Torchmodel
from config import Config
import torch
from loader import load_data
from tqdm import tqdm
import numpy as np
from main import logger
from scipy.stats import pearsonr
from transformers import AutoTokenizer


class Predict:
    """Inference wrapper around a trained ``Torchmodel`` checkpoint.

    Loads the checkpoint weights, puts the model in eval mode and exposes
    single-pair scoring (``predict``) and batched evaluation against labels
    (``batch_predict``).
    """

    def __init__(self, model_save_path, config, data=None):
        """Build the model from *config* and load weights from *model_save_path*.

        Args:
            model_save_path: Path of the saved ``state_dict`` checkpoint.
            config: Mapping passed to ``Torchmodel``; ``"pretrain_model_path"``
                is read from it to load the tokenizer.
            data: Optional iterable of batches used by ``batch_predict``. Each
                batch must unpack into three tensors: inputs, target, label.
        """
        self.config = config
        self.model = Torchmodel(config)
        # map_location pins every stored tensor to the CPU so that a checkpoint
        # written during a GPU run still loads on a machine without CUDA.
        state_dict = torch.load(model_save_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)
        self.test_data = data
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(config["pretrain_model_path"])

    def _model_device(self):
        """Return the device holding the model's parameters (CPU if it has none)."""
        first_param = next(self.model.parameters(), None)
        return torch.device('cpu') if first_param is None else first_param.device

    def predict(self, inputs, target, max_length=20):
        """Score a single text pair with the loaded model.

        Args:
            inputs: First text of the pair.
            target: Second text of the pair.
            max_length: Token length both texts are padded/truncated to.

        Returns:
            Whatever the model returns when called with the two encoded
            sequences and no labels (called ``distance_score`` in this file).
        """
        device = self._model_device()
        input_ids = self.tokenizer.encode(
            inputs, max_length=max_length, padding='max_length', truncation=True)
        target_ids = self.tokenizer.encode(
            target, max_length=max_length, padding='max_length', truncation=True)
        # Wrap each id list in an outer list to add the batch dimension, and
        # place the tensors on the same device as the model.
        input_ids = torch.LongTensor([input_ids]).to(device)
        target_ids = torch.LongTensor([target_ids]).to(device)
        with torch.no_grad():
            # No labels are passed: the model predicts with its current weights.
            distance_score = self.model(input_ids, target_ids)
        return distance_score

    def batch_predict(self, input=None, target=None):  # noqa: A002
        """Evaluate the model on ``self.test_data`` and return the mean Pearson r.

        The Pearson correlation between model output and label is computed per
        batch; the per-batch values are kept in ``self.all_pearson_score`` and
        their mean is logged and returned.

        Args:
            input: Unused; kept only so existing two-argument calls still work.
            target: Unused; kept only so existing two-argument calls still work.

        Raises:
            ValueError: If no evaluation data was supplied at construction.
        """
        if self.test_data is None:
            raise ValueError("batch_predict requires `data` to be passed to Predict()")
        # Batches must live on the same device as the model. The model is
        # loaded on the CPU, so unconditionally moving batches to CUDA (as was
        # done before) fails with a device mismatch on GPU machines.
        device = self._model_device()
        self.all_pearson_score = []
        for batch_data in tqdm(self.test_data):
            batch_inputs, batch_target, label = [d.to(device) for d in batch_data]
            with torch.no_grad():
                # No labels are passed: the model predicts with its current weights.
                distance_score = self.model(batch_inputs, batch_target)
            self.all_pearson_score.append(self.pearsSim(distance_score, label))
        mean_pearson = np.mean(self.all_pearson_score)
        logger.info("epoch average pearson: %f", mean_pearson)
        return mean_pearson

    def pearsSim(self, tensor1, tensor2):
        """Return the Pearson correlation coefficient between two tensors.

        Both tensors are detached, copied to the CPU and squeezed before the
        coefficient is computed with ``scipy.stats.pearsonr``.
        """
        array1 = tensor1.detach().cpu().numpy().squeeze()
        array2 = tensor2.detach().cpu().numpy().squeeze()
        # pearsonr returns (statistic, p-value); only the statistic is needed.
        return pearsonr(array1, array2)[0]

    def save_predict_result(self):
        """Placeholder for persisting predictions; not implemented."""
        pass


if __name__ == '__main__':
    # Script entry point: score a sample of the test CSV pair by pair and
    # write the predictions next to the reference scores in an Excel file.
    model_save_path = 'output/deberta-v3-small_epoch_50.pth'
    inference_model = Predict(model_save_path, Config)

    # NOTE(review): the key is spelled with a double underscore; confirm it
    # matches the key defined in the config module.
    test_data_path = Config['test__data_path']
    test_data = pd.read_csv(test_data_path)
    # Use a reproducible subset of at most 400 rows. Capping at the frame
    # length avoids the ValueError sample() raises when n exceeds the row count.
    test_data = test_data.sample(min(400, len(test_data)), random_state=42)
    # NOTE(review): the model is fed the 'input' column while the report below
    # writes the 'anchor' column; confirm both columns exist in the CSV.
    anchor, targets = test_data['input'].values, test_data['target'].values
    out_test_data = test_data[['id', 'anchor', 'target', 'score']].copy()
    # Each prediction is stored as a string formatted to three decimals.
    out_pred_score = [
        '%.3f' % inference_model.predict(anchor_text, target_text)
        for anchor_text, target_text in tqdm(zip(anchor, targets), total=len(test_data))
    ]
    out_test_data['pred_score'] = out_pred_score
    out_test_data.to_excel('test_data_predict.xlsx', index=False)
